import requests
import functools
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import os
import tiktoken
from datetime import datetime, timedelta
import subprocess
from poe_api_wrapper import PoeApi
import openai
import random
import base64
import codecs
import re

# Function to encode the image
def encode_image(image_path):
  """Return the base64-encoded contents of the image at *image_path*.

  :param image_path: path to an image file on disk
  :return: base64 string (utf-8, no data-URI prefix)
  """
  # BUGFIX: removed a leftover debug print of the path on every call.
  # Read in binary mode; b64encode returns bytes, so decode to str.
  with open(image_path, "rb") as image_file:
    return base64.b64encode(image_file.read()).decode('utf-8')
def decode_unicode_escapes(text):
    """Expand literal ``\\uXXXX`` escape sequences in *text*.

    Each four-hex-digit escape is replaced by the character it denotes;
    everything else is returned unchanged.

    :param text: string possibly containing literal ``\\uXXXX`` sequences
    :return: the string with those sequences decoded
    """
    pattern = re.compile(r'\\u[0-9a-fA-F]{4}')
    return pattern.sub(
        lambda m: codecs.decode(m.group(0), 'unicode_escape'),
        text,
    )


from core import setting, mgClient
import pymongo as mg
# Shared MongoDB handles for the chat service.
db = mgClient['chat']
# NOTE(review): "traiffic" is misspelled, but it is the live collection name;
# renaming it would orphan existing rate-limit data — confirm before fixing.
# Per-IP request logs used by check_traffic / check_traffic_4 below.
traffic_collection = db.get_collection('traiffic')
traffic_collection_4 = db.get_collection('traiffic_4')
traffic_collection_poe = db.get_collection('traiffic_poe')
user_collection = db.get_collection('user')

async def async_request_gemini_vision(prompt):
    """Build a Gemini Pro Vision request from *prompt* and POST it.

    :param prompt: dict with ``content`` (text) plus image material in one
        of two forms: ``assets`` (keys resolved to snapshot URLs via the
        asset service) or ``filename`` (a local file under
        /var/www/html/images).
    :return: the streaming ``requests.Response``; on request failure a bare
        Response with status_code 500 is returned instead of raising.
    """
    parts = [{"text": prompt["content"]}]
    loop = asyncio.get_running_loop()

    if "filename" not in prompt:
        # Resolve asset keys to downloadable snapshot URLs.
        payload = {"keys": prompt['assets']}
        list_url = "https://user.chatcns.com/asset/list"
        request = functools.partial(requests.post, list_url, json=payload, timeout=30)
        try:
            with ThreadPoolExecutor() as executor:
                response = await loop.run_in_executor(executor, request)
        except Exception:
            # Deliberate best-effort: signal failure via a synthetic 500.
            response = requests.Response()
            response.status_code = 500
        if response.status_code == 200:
            for item in response.json()['list']:
                request = functools.partial(requests.get, item['snapshot'], timeout=30)
                with ThreadPoolExecutor() as executor:
                    response = await loop.run_in_executor(executor, request)
                encoded_image = base64.b64encode(response.content).decode('utf-8')
                parts.append({
                    "inline_data": {
                        "mime_type": "image/jpeg",
                        "data": encoded_image,
                    }
                })
    else:
        filename = prompt['filename']
        # BUGFIX: ``filename`` was assigned but never interpolated into the
        # path, so the local-file branch could never find an uploaded image.
        filepath = os.path.join('/var/www/html/images', f'{filename}.jpg')
        if os.path.exists(filepath):
            parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": encode_image(filepath),
                }
            })

    payload = {
        # Disable all of Gemini's content filters for this endpoint.
        "safetySettings": [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        ],
        "contents": [{"parts": parts}],
    }
    # Debug print of the payload removed: it dumped base64 image data to stdout.

    api_keys = [
        'yourkey'
    ]
    api_key = random.choice(api_keys)

    # Stream the content from the API endpoint.
    url = f'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:streamGenerateContent?key={api_key}'
    headers = {'Content-Type': 'application/json'}
    request = functools.partial(requests.post, url, headers=headers, json=payload, stream=True, timeout=20)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        response = requests.Response()
        response.status_code = 500
    return response
async def async_request_gemini(content):
    """POST *content* to the Gemini Pro streaming endpoint.

    :param content: the ``contents`` payload in Gemini API format
    :return: the streaming ``requests.Response``; on request failure a bare
        Response with status_code 500 is returned instead of raising.
    """
    payload = {
        # Disable all of Gemini's content filters for this endpoint.
        "safetySettings": [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        ],
        "contents": content,
    }

    api_keys = [
        'yourkey'
    ]
    api_key = random.choice(api_keys)

    # Stream the content from the API endpoint.
    url = f'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key={api_key}'
    headers = {'Content-Type': 'application/json'}

    loop = asyncio.get_running_loop()
    request = functools.partial(requests.post, url, headers=headers, json=payload, stream=True, timeout=20)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # BUGFIX: the original printed ``response.text`` here, but when the
        # request itself raised, ``response`` was still unbound → NameError
        # masked the real error. Keep the 500-fallback behaviour instead.
        response = requests.Response()
        response.status_code = 500
    return response
def generate_text_gemini(response):
    """Yield text fragments from a Gemini streaming JSON response.

    Scans the body line by line for ``"text"`` fields, expands the common
    backslash escapes plus literal ``\\uXXXX`` sequences, and yields each
    fragment. Malformed lines are skipped (best-effort, as before).

    :param response: a streaming ``requests.Response``
    """
    for raw in response.iter_lines():
        if not raw:
            continue  # keep-alive newline
        try:
            stripped = raw.decode('utf-8').strip()
            if stripped.startswith('"text"'):
                # '"text": "' is 9 characters; the trailing '"' is dropped too.
                fragment = stripped[9:-1]
                fragment = fragment.replace('\\n', '\n').replace('\\t', '\t').replace('\\"', '"')
                yield decode_unicode_escapes(fragment)
        except:
            pass
def generate_text_glm(response, save_file_path=None, save_flag=0):
    """Yield the data of each SSE event from a GLM streaming response.

    :param response: object exposing ``events()`` yielding items with ``.data``
    :param save_file_path: accepted for interface compatibility; unused
    :param save_flag: accepted for interface compatibility; unused
    """
    for event in response.events():
        yield event.data

def check_traffic_4(ip):
    """Rate-limit gpt-4 usage per IP over a rolling 24-hour window.

    :param ip: client IP address
    :return: True when the IP is over its limit (refuse the request),
        False when the request is allowed (and has been recorded).
    """
    now = datetime.now()
    traffic = traffic_collection_4.find_one({"ip": ip})
    if not traffic:
        # First request from this IP: record it and allow.
        traffic_collection_4.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # BUGFIX: prune entries older than 24h BEFORE counting. The original
    # summed first, so expired traffic still counted toward the limit and
    # an IP could remain blocked indefinitely.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 1:  # current policy: one gpt-4 request per IP per 24h
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection_4.update_one({"ip": ip}, {"$set": traffic})  # persist pruned + new entry
    return False
def check_traffic(ip):
    """Rate-limit general usage per IP over a rolling 24-hour window.

    :param ip: client IP address
    :return: True when the IP is over its limit (refuse the request),
        False when the request is allowed (and has been recorded).
    """
    now = datetime.now()
    traffic = traffic_collection.find_one({"ip": ip})
    if not traffic:
        # First request from this IP: record it and allow.
        traffic_collection.insert_one({"ip": ip, "data": [{"timestamp": now, "count": 1}]})
        return False
    # BUGFIX: prune entries older than 24h BEFORE counting. The original
    # summed first, so expired traffic still counted toward the limit.
    traffic["data"] = [t for t in traffic["data"] if now - t["timestamp"] <= timedelta(hours=24)]
    recent_count = sum(t["count"] for t in traffic["data"])
    if recent_count >= 400:  # current policy: 400 requests per IP per 24h
        return True
    traffic["data"].append({"timestamp": now, "count": 1})
    traffic_collection.update_one({"ip": ip}, {"$set": traffic})  # persist pruned + new entry
    return False

def generate_error_user():
    """Yield the single out-of-funds notice shown to the user."""
    yield "没钱啦！"

def num_tokens_from_message(messages, model="gpt-3.5-turbo-0301"):
    """Estimate the prompt-token count for a chat-completion message list.

    :param messages: list of dicts; every value in each dict is tokenized
    :param model: model name used to pick the tokenizer; unknown models
        fall back to the ``cl100k_base`` encoding
    :return: estimated number of tokens
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model name: use the encoding shared by the gpt-3.5/4 family.
        encoding = tiktoken.get_encoding("cl100k_base")

    total = 2  # every reply is primed with <im_start>assistant
    for message in messages:
        total += 4  # per-message framing: <im_start>{role/name}\n{content}<im_end>\n
        for key, value in message.items():
            total += len(encoding.encode(value))
            if key == "name":
                total -= 1  # when a name is present the role token is omitted
    return total

def consume_token(user_key, token_num):
    """Deduct *token_num* tokens from the balance stored for *user_key*.

    The balance lives in ``key_token/<user_key>`` as a plain integer.
    When the balance would go negative, the key file is deleted (the key
    is revoked); otherwise the remaining balance is written back.

    :param user_key: file name of the key under ``key_token/``
    :param token_num: number of tokens to subtract
    """
    key_path = os.path.join('key_token', user_key)
    with open(key_path, 'r') as fh:
        remaining = int(fh.read()) - token_num
    if remaining < 0:
        os.remove(key_path)
    else:
        with open(key_path, 'w') as fh:
            fh.write(str(remaining))
    
def get_api_key(api_keys, api_key_index):
    """Return the key at *api_key_index* plus the next round-robin index.

    :param api_keys: non-empty list of API keys
    :param api_key_index: current position in the rotation
    :return: tuple ``(api_key, next_index)``; the caller must store
        ``next_index`` for the following call (indices do not persist here).
    """
    # BUGFIX: removed a leftover debug print of the index on every call.
    api_key = api_keys[api_key_index]
    # Wrap around so the rotation cycles through the whole list.
    next_index = (api_key_index + 1) % len(api_keys)
    return api_key, next_index

def get_api_keys(key_file_path):
    """Read API keys from *key_file_path*, one key per line.

    :param key_file_path: path to a text file containing one key per line
    :return: list of keys with surrounding whitespace stripped
    """
    with open(key_file_path, 'r') as fh:
        return [line.strip() for line in fh]

def generate_error_key():
    """Yield the notice shown when free gpt-4 access is rate-limited,
    pointing the user at the paid-store link and usage instructions."""
    yield "由于流量过大，本站限制了4.0的免费供应，请耐心等候，如您需要更多次数，请您访问以下网址：[https://jufahuo.com/store/t3XKJV](https://jufahuo.com/store/t3XKJV)，或者直接微信与我联系：remifafamirefa。 购买前请务必仔细阅读以下使用说明：[https://gitee.com/xu-zhanwei/chatanywhere/blob/master/4.md](https://gitee.com/xu-zhanwei/chatanywhere/blob/master/4.md)"
def generate_error_traffic():
    """Yield the notice shown when the per-IP rate limit is exceeded."""
    yield "我们聊的太多了，休息一会儿再来吧。"
def generate_error():
    """Yield the generic upstream-failure notice asking the user to retry."""
    yield "出了点小状况，大概率是开发者太穷租不起大流量，或者调用的服务有问题，可以再试一次吗？"
def generate_error_path():
    """Yield the notice shown when an uploaded image has been cleaned up."""
    yield "图片已被清理，请重新上传。"
def generate_error_max():
    """Yield the notice shown when the conversation exceeds the context limit."""
    yield "对话内容太多了，您最好新开一个会话再试一次或者换用其他模型。"

# async def async_request_glm(prompt: str):
#     print(prompt)
#     response = zhipusdk.model_api.sse_invoke(
#         model="chatglm_130b",
#         prompt=prompt,
#         temperature=0.95,
#         top_p=0.7,
#         incremental=True
#     )
#     print(response)
#     return response

async def async_request(prompt: str, api_key: str):
    """POST a streaming gpt-3.5-turbo-16k chat-completion request.

    :param prompt: the ``messages`` payload (despite the ``str`` annotation,
        callers pass the OpenAI messages list — kept for compatibility)
    :param api_key: OpenAI API key used as a Bearer token
    :return: the streaming ``requests.Response``; on request failure a bare
        Response with status_code 500 is returned instead of raising.
    """
    data = {
        "model": "gpt-3.5-turbo-16k",
        "messages": prompt,
        "temperature": 1,
        "stream": True,  # enable the streaming API
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    url = "https://api.openai.com/v1/chat/completions"
    loop = asyncio.get_running_loop()
    request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=30)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # BUGFIX: was a bare ``except:`` that also swallowed KeyboardInterrupt
        # and CancelledError; keep the 500-fallback for real request errors only.
        response = requests.Response()
        response.status_code = 500
    return response

async def async_request_4(prompts, url = "https://api.openai.com/v1/chat/completions",key_value = "Bearer yourkey", model = "gpt-4-vision-preview"):
    """POST a streaming gpt-4(-vision) chat-completion request.

    Builds multi-part messages from *prompts*: text plus optional images,
    either from a local file (``filename``) or from the asset service
    (``assets`` keys resolved to snapshot URLs).

    :param prompts: list of dicts with ``role``/``content`` and optionally
        ``filename`` or ``assets``
    :param url: completion endpoint (was previously ignored — see BUGFIX note)
    :param key_value: Authorization header value; NOTE(review): overridden
        below by a random choice from the hard-coded key pool, as before
    :param model: model name to request
    :return: the streaming ``requests.Response``; on request failure a bare
        Response with status_code 500 is returned instead of raising.
    """
    key_values = [
        'yourkey'
    ]
    key_value = random.choice(key_values)
    messages = []
    for prompt in prompts:
        message = {"role": prompt['role'], "content": [{"type": "text", "text": prompt['content']}]}
        if 'filename' in prompt:
            filename = prompt['filename']
            # BUGFIX: ``filename`` was assigned but never interpolated into
            # the path, so local images were never found.
            filepath = os.path.join('/var/www/html/images', f'{filename}.jpg')
            if os.path.exists(filepath):
                base64_image = encode_image(filepath)
                message['content'].append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
        if 'assets' in prompt:
            payload = {"keys": prompt['assets']}
            # BUGFIX: use a local name so the ``url`` parameter is not clobbered.
            asset_url = "https://user.chatcns.com/asset/list"
            loop = asyncio.get_running_loop()
            request = functools.partial(requests.post, asset_url, json=payload, timeout=30)
            try:
                with ThreadPoolExecutor() as executor:
                    response = await loop.run_in_executor(executor, request)
            except Exception:
                response = requests.Response()
                response.status_code = 500
            if response.status_code == 200:
                for item in response.json()['list']:
                    message['content'].append({"type": "image_url", "image_url": {"url": item['snapshot']}})
        messages.append(message)

    data = {
        "model": model,
        "messages": messages,
        "max_tokens": 3000,
        "stream": True,  # enable the streaming API
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": key_value,
    }
    loop = asyncio.get_running_loop()
    # BUGFIX: the endpoint was unconditionally reset here, silently ignoring
    # the ``url`` parameter; the default value is unchanged, so existing
    # callers see identical behaviour.
    request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=30)
    try:
        with ThreadPoolExecutor() as executor:
            response = await loop.run_in_executor(executor, request)
    except Exception:
        # BUGFIX: the original printed ``response.status_code`` here, but when
        # the request raised, ``response`` was unbound → NameError.
        response = requests.Response()
        response.status_code = 500
    return response


def get_save_file_path(client_ip, message):
    """Create a new numbered log file for *client_ip* and append the last message.

    Files live under ``log/<client_ip>/`` and are named ``1.txt``, ``2.txt``, …
    The next number is derived from the most recently modified file in the
    directory (so every call creates a fresh file).

    :param client_ip: client IP, used as the per-client log directory name
    :param message: message list; only ``message[-1]`` is written
    :return: path of the file that was written
    """
    log_dir = os.path.join('log', client_ip)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Most recently modified file determines the next sequence number.
    entries = os.listdir(log_dir)
    entries.sort(key=lambda fn: os.path.getmtime(os.path.join(log_dir, fn)))
    if not entries:
        file_name = os.path.join(log_dir, '1.txt')
    else:
        next_num = int(entries[-1].split('.')[0]) + 1
        file_name = os.path.join(log_dir, str(next_num) + '.txt')

    # Append the latest message (repr of the dict) to the new file.
    with open(file_name, 'a') as fh:
        fh.write(str(message[-1]))
    return file_name

def generate_text_glm(response,save_file_path=None,save_flag=0):
    """Yield the data of each SSE event from a GLM streaming response.

    NOTE(review): this is an exact duplicate of ``generate_text_glm`` defined
    earlier in this file; being later, this definition is the one that wins.
    Consider deleting one copy. ``save_file_path`` and ``save_flag`` are
    accepted but unused.
    """
    for event in response.events():
        text = event.data
        # text = text.replace('\\n','\n').replace('\\t','\t')
        # text = text.replace('\\n','\n')
        yield text
def generate_text_openai(response):
    """Yield content chunks from an OpenAI SDK streaming response.

    Literal ``\\n`` / ``\\t`` escape sequences inside a chunk are expanded
    before yielding; chunks with no choices or empty deltas are skipped.

    :param response: iterable of chat-completion chunk dicts
    """
    for chunk in response:
        if "choices" not in chunk or not chunk["choices"]:
            continue
        piece = chunk["choices"][0]["delta"].get("content")
        if not piece:
            continue
        yield piece.replace('\\n', '\n').replace('\\t', '\t')

def generate_text(response, user_key=None):
    """Yield content chunks from a streaming chat-completion HTTP response.

    Each body chunk is split on ``data: `` SSE markers; fragments that parse
    as JSON and carry a delta ``content`` are yielded as they arrive.

    :param response: a ``requests.Response`` opened with ``stream=True``
    :param user_key: optional token-balance key; when given, the token count
        of the full generated text is deducted via ``consume_token`` after
        the stream ends
    """
    all_text = []
    for event in response.iter_content(chunk_size=None):
        response_str = event.decode('utf-8')
        # BUGFIX: the loop variable used to shadow the builtin ``str``.
        for fragment in response_str.split('data: '):
            try:
                response_dict = json.loads(fragment)
                if "choices" in response_dict and response_dict["choices"]:
                    text = response_dict["choices"][0]["delta"].get("content")
                    if text:
                        all_text.append(text)
                        yield text
            except Exception:
                # Keep-alives, "[DONE]" markers and JSON split across chunk
                # boundaries all fail here; skipping them is deliberate.
                pass
    if user_key:
        # Charge the user for the assistant's full reply.
        message = [{"role": "assistant", "content": ''.join(all_text)}]
        token_num = num_tokens_from_message(message, model="gpt-4")
        consume_token(user_key, token_num)
        # BUGFIX: removed a leftover debug print ('adadfsaf').
